# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gzip
import os
import pickle
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

# Module-level logger shared by the MNIST dataset class and the idx parsers below.
logger = get_logger(__name__)


class MNIST(VisionDataset):
    r"""``Dataset`` for the MNIST handwritten-digit data.

    Each sample is a tuple ``(image, image_category)``: the image is a
    ``np.uint8`` array of shape ``(height, width, 1)`` produced by
    :func:`parse_idx3`, and the category comes from the label array produced
    by :func:`parse_idx1`.
    """

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    url prefix for downloading raw files
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    raw file names of both training set and test set (10k)
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    md5 for checking raw files
    """
    train_file = "train.pkl"
    """
    default pickle file name of training set and its meta data
    """
    test_file = "test.pkl"
    """
    default pickle file name of test set and its meta data
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        initialization:

        1. check root path and target file (train or test)
        2. check target file exists

           * if exists:

             * load pickle file as meta-data and data in MNIST dataset

           * else if all raw ``.gz`` files already exist in ``root``:

             * process them into pickle files, then load the target one

           * else:

             * if download:

               a. download all raw data (both train and test set) by url
               b. process raw data ( idx3/idx1 -> dict (meta-data), numpy.array (data) )
               c. save meta-data and data as pickle file
               d. load pickle file as meta-data and data in MNIST dataset

        :param root: path for mnist dataset downloading or loading, if ``None``,
            set ``root`` to the ``_default_root`` (created if missing)
        :param train: if ``True``, load the training set, else load the test set
        :param download: if neither the target pickle file nor all raw files
            exist and download is ``True``, download raw files, process them,
            then load; otherwise raise ValueError. Default is True
        :param timeout: passed to ``load_raw_data_from_url`` for every raw file
            download, default is 500
        :raises ValueError: if ``root`` is given but does not exist, or if the
            data is missing and ``download`` is ``False``
        """
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            # exist_ok avoids failing if the directory appears between check and create
            os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                raise ValueError("dir %s does not exist" % self.root)

        # choose the target pickle file
        self.target_file = os.path.join(
            self.root, self.train_file if train else self.test_file
        )

        # an existing pickle file is loaded no matter what ``download`` is set to;
        # otherwise build it from raw files, downloading them first if allowed
        if not os.path.exists(self.target_file):
            if self._check_raw_files():
                self.process()
            elif download:
                self.download()
            else:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % self.target_file
                )
        self._meta_data, self.arrays = self._load_file(self.target_file)

    def __getitem__(self, index: int) -> Tuple:
        # one element from each stored array, in ("image", "image_category") order
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        # per-class sub-directory of the shared dataset root
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        r"""Meta data dict (``"images"`` / ``"labels"`` headers) of the loaded split."""
        return self._meta_data

    def _load_file(self, target_file):
        r"""Unpickle ``(meta_data, arrays)`` from ``target_file``."""
        # NOTE: pickle.load can execute arbitrary code; only load files written
        # by ``process`` into a trusted ``root``.
        with open(target_file, "rb") as f:
            return pickle.load(f)

    def _check_raw_files(self):
        r"""Return ``True`` if every raw ``.gz`` file exists in ``self.root``."""
        return all(
            os.path.exists(os.path.join(self.root, path)) for path in self.raw_file_name
        )

    def download(self):
        r"""Fetch every raw file into ``self.root`` and then run :meth:`process`."""
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
        self.process()

    def process(self):
        r"""Parse all four raw files and write both ``train_file`` and
        ``test_file`` pickles into ``self.root``.

        Each pickle holds ``(meta_data, (images, labels))`` where ``meta_data``
        is ``{"images": <idx3 header>, "labels": <idx1 header>}``.
        """
        logger.info("process raw data ...")
        raw_paths = [os.path.join(self.root, name) for name in self.raw_file_name]
        meta_data_images_train, images_train = parse_idx3(raw_paths[0])
        meta_data_labels_train, labels_train = parse_idx1(raw_paths[1])
        meta_data_images_test, images_test = parse_idx3(raw_paths[2])
        meta_data_labels_test, labels_test = parse_idx1(raw_paths[3])

        meta_data_train = {
            "images": meta_data_images_train,
            "labels": meta_data_labels_train,
        }
        meta_data_test = {
            "images": meta_data_images_test,
            "labels": meta_data_labels_test,
        }
        dataset_train = (images_train, labels_train)
        dataset_test = (images_test, labels_test)

        # save both training set and test set as pickle files
        with open(os.path.join(self.root, self.train_file), "wb") as f:
            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(self.root, self.test_file), "wb") as f:
            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)


def parse_idx3(idx3_file):
    r"""Parse a gzip-compressed idx3-ubyte image file.

    :param idx3_file: path to a ``.gz`` file in idx3-ubyte format.
    :return: ``(meta_data, images)`` where ``meta_data`` is a dict with keys
        ``"magic"``, ``"imgs"``, ``"height"`` and ``"width"`` read from the
        big-endian header, and ``images`` is a list of ``imgs`` writable
        ``np.uint8`` arrays of shape ``(height, width, 1)``.
    :raises ValueError: if the magic number is not 2051 (unsigned-byte data,
        3 dimensions) or the file holds fewer pixels than the header declares.
    """
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data: four big-endian int32 (magic, count, rows, cols)
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, 0)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}
    if magic != 2051:
        raise ValueError(
            "invalid idx3 magic number %d in %s (expected 2051)" % (magic, idx3_file)
        )

    # decode all pixels in one vectorized call instead of one struct call per
    # image; np.frombuffer raises ValueError if the data is shorter than
    # declared. copy() makes the result writable (a view over bytes is read-only).
    pixels = np.frombuffer(
        bin_data,
        dtype=np.uint8,
        count=imgs * height * width,
        offset=struct.calcsize(fmt_header),
    ).copy()
    images = list(pixels.reshape((imgs, height, width, 1)))
    return meta_data, images


def parse_idx1(idx1_file):
    r"""Parse a gzip-compressed idx1-ubyte label file.

    :param idx1_file: path to a ``.gz`` file in idx1-ubyte format.
    :return: ``(meta_data, labels)`` where ``meta_data`` is a dict with keys
        ``"magic"`` and ``"imgs"`` read from the big-endian header, and
        ``labels`` is a 1-D ``int`` numpy array of length ``imgs``.
    :raises ValueError: if the magic number is not 2049 (unsigned-byte data,
        1 dimension) or the file holds fewer labels than the header declares.
    """
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data: two big-endian int32 (magic, count)
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, 0)
    meta_data = {"magic": magic, "imgs": imgs}
    if magic != 2049:
        raise ValueError(
            "invalid idx1 magic number %d in %s (expected 2049)" % (magic, idx1_file)
        )

    # read exactly ``imgs`` unsigned bytes; np.frombuffer raises ValueError on
    # short data instead of leaving uninitialized entries as np.empty did.
    # astype(int) keeps the original dtype and yields a writable array.
    labels = np.frombuffer(
        bin_data, dtype=np.uint8, count=imgs, offset=struct.calcsize(fmt_header)
    ).astype(int)
    return meta_data, labels
